#!/bin/bash
#
# Sweep the SAE-aware LSH evaluation (eval_sae_aware_lsh.py, --lsh_only mode)
# over every combination of num_hash_tables x num_hash_functions for one Gemma
# model, running inside the "jax" conda environment.

# Fail fast if conda is unavailable or the environment cannot be activated;
# otherwise every run below would silently use whatever python is on PATH.
if ! command -v conda >/dev/null 2>&1; then
    printf 'error: conda not found on PATH\n' >&2
    exit 1
fi
conda_hook=$(conda shell.bash hook) || {
    printf 'error: "conda shell.bash hook" failed\n' >&2
    exit 1
}
eval "$conda_hook"
conda activate jax || {
    printf 'error: could not activate conda environment "jax"\n' >&2
    exit 1
}

# Turn off JAX/XLA GPU memory preallocation and use the platform allocator
# (allocate/free on demand), per the JAX "GPU memory allocation" docs.
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform

# --- Changeable parameters ---------------------------------------------------
# Short model id; every artifact path below is derived from it.
export GEMMA_MODEL_NAME_SHORT="gemma-2-2b"
# Full "org/name" model id handed to --model_name.
export GEMMA_MODEL_NAME="google/${GEMMA_MODEL_NAME_SHORT}"
# Input artifacts passed to the evaluation script.
export SAE_MODEL_PATH="${HOME}/${GEMMA_MODEL_NAME_SHORT}-sae/k5_final_sae_model.pkl"
export SAE_CODE_PATH="${HOME}/${GEMMA_MODEL_NAME_SHORT}-sae/k5_whole_sae_final_z.npy"
export MLP_MODEL_PATH="${HOME}/dual-map/model/${GEMMA_MODEL_NAME_SHORT}/dual_map_${GEMMA_MODEL_NAME_SHORT}.pt"

# --- Sweep grid --------------------------------------------------------------
# Each value in hash_tables is paired with each value in hash_functions.
declare -a hash_tables=(1024 2048)
declare -a hash_functions=(2 4)

# Run the evaluation once per (num_hash_tables, num_hash_functions) pair.
# A failing run is reported and the sweep continues with the next pair;
# the script exits non-zero at the end if any run failed.
eval_script="$HOME/src/eval/sae-softmax/eval_sae_aware_lsh.py"
failed_runs=()

for num_tables in "${hash_tables[@]}"; do
    for num_functions in "${hash_functions[@]}"; do
        printf 'Testing with num_hash_tables=%s and num_hash_functions=%s\n' \
            "$num_tables" "$num_functions"
        # All expansions are quoted so paths containing spaces or glob
        # characters (e.g. in $HOME) reach python as single arguments.
        if ! python "$eval_script" \
            --total_samples 1000 \
            --whitening=True \
            --lsh_only \
            --num_hash_tables "$num_tables" \
            --num_hash_functions "$num_functions" \
            --model_name "$GEMMA_MODEL_NAME" \
            --sae_model_path "$SAE_MODEL_PATH" \
            --sae_code_path "$SAE_CODE_PATH" \
            --mlp_model_path "$MLP_MODEL_PATH"; then
            printf 'Run failed: num_hash_tables=%s num_hash_functions=%s\n' \
                "$num_tables" "$num_functions" >&2
            failed_runs+=("tables=${num_tables},functions=${num_functions}")
        fi
    done
done

# Surface failures of any run, not just the last one.
if (( ${#failed_runs[@]} > 0 )); then
    printf '%d run(s) failed: %s\n' "${#failed_runs[@]}" "${failed_runs[*]}" >&2
    exit 1
fi
